{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A Simple Mistral 7B RAG with Weaviate Vector DB"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../images/simple-rag-1-db.webp\" width=700>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1) Load the embedding model and LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62ecf205130442478b85c95584795ef8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "915b5afefaa540d181c2a074f10abaf1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c5b81f87017a48aaacd54e76fbe78692",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md:   0%|          | 0.00/94.8k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f201eb94b1e54614b19429145e12b2e6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "sentence_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb5acc3349d14cadb43957fd44c47f60",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2af75f35285443f0bff8de82d8e4bcf8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/133M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fa7e2eb4733d460eab123ee181280057",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/366 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62b0bb0ba2da448faf384f786f80423b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a5d923cd168f4ad09e441a7352412412",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "65f6646d3a434e6084966d84b6cb121f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f30e54fbbd214360b795c5862a487d41",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2c9ebd85387f4e12b5766d3223494241",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Some parameters are on the meta device device because they were offloaded to the cpu and disk.\n"
     ]
    }
   ],
   "source": [
    "from llama_index.core import Settings, PromptTemplate\n",
    "from llama_index.core.embeddings import resolve_embed_model\n",
    "from llama_index.embeddings.huggingface import HuggingFaceEmbedding\n",
    "from llama_index.llms.huggingface import HuggingFaceLLM\n",
    "\n",
    "import torch\n",
    "\n",
    "# Register the embedding model globally through llama-index `Settings`.\n",
    "Settings.embed_model = HuggingFaceEmbedding(model_name=\"BAAI/bge-small-en-v1.5\")\n",
    "\n",
    "# Register the LLM globally through llama-index `Settings`.\n",
    "Settings.llm = HuggingFaceLLM(\n",
    "    context_window=2048,\n",
    "    max_new_tokens=256,\n",
    "    # Greedy decoding. `temperature` is only used in sample-based generation,\n",
    "    # so it is not passed together with `do_sample=False` (transformers emits\n",
    "    # a UserWarning when both are set).\n",
    "    generate_kwargs={\"do_sample\": False},\n",
    "    model_name=\"mistralai/Mistral-7B-Instruct-v0.2\",\n",
    "    tokenizer_name=\"mistralai/Mistral-7B-Instruct-v0.2\",\n",
    "    device_map=\"auto\",\n",
    "    model_kwargs={\n",
    "        \"torch_dtype\": torch.bfloat16,\n",
    "        # Since we are using a small GPU with limited memory\n",
    "        # Set to False if you have a large GPU to speed things up.\n",
    "        \"offload_buffers\": True,\n",
    "    },\n",
    ")\n",
    "\n",
    "Settings.chunk_size = 512"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2) Set up vector database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import weaviate\n",
    "from llama_index.vector_stores.weaviate import WeaviateVectorStore\n",
    "from llama_index.core import StorageContext\n",
    "\n",
    "# Read the connection settings from the environment so that real credentials\n",
    "# never have to be typed into (or committed with) the notebook. The fallback\n",
    "# values are placeholders only.\n",
    "weaviate_url = os.environ.get(\n",
    "    \"WEAVIATE_URL\", \"https://my-test-3lb8837z.weaviate.network\"\n",
    ")\n",
    "\n",
    "resource_owner_config = weaviate.AuthClientPassword(\n",
    "    username=os.environ.get(\"WEAVIATE_USERNAME\", \"abc\"),\n",
    "    password=os.environ.get(\"WEAVIATE_PASSWORD\", \"xyz\"),\n",
    ")\n",
    "\n",
    "client = weaviate.Client(\n",
    "    weaviate_url,\n",
    "    auth_client_secret=resource_owner_config,\n",
    ")\n",
    "\n",
    "# Vector store backed by the Weaviate client, stored under index \"LlamaIndex\".\n",
    "vector_store = WeaviateVectorStore(\n",
    "    weaviate_client=client, index_name=\"LlamaIndex\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3) Load data / read in documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Read documents: {'Basic-Scientific-Food-Preparation-Lab-Manual.txt', 'Basic-Scientific-Food-Preparation-Lab-Manual.pdf'}\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader\n",
    "\n",
    "# Directory holding the example documents, relative to this notebook.\n",
    "data_dir = Path(\"..\", \"data\", \"example-1\")\n",
    "documents = SimpleDirectoryReader(data_dir).load_data()\n",
    "\n",
    "# Sanity check: print the distinct file names found in the loaded documents.\n",
    "unique_docs = {doc.metadata[\"file_name\"] for doc in documents}\n",
    "print(f\"Read documents: {unique_docs}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Building the index populates the database; this may take a few minutes.\n",
    "# NOTE(review): re-running this cell presumably inserts the documents again - confirm.\n",
    "storage_context = StorageContext.from_defaults(\n",
    "    vector_store=vector_store,\n",
    ")\n",
    "index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4) Set up custom prompt template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import PromptTemplate\n",
    "\n",
    "# Prompt layout: the context block between two separator rules, then the question.\n",
    "template = (\n",
    "    \"We have provided context information below. \\n\"\n",
    "    \"---------------------\\n\"\n",
    "    \"{context_str}\\n\"\n",
    "    \"---------------------\\n\"\n",
    "    \"Given this information, please answer the question: {query_str}\\n\"\n",
    ")\n",
    "qa_template = PromptTemplate(template)\n",
    "\n",
    "# Create the query engine and swap in the custom text QA template.\n",
    "query_engine = index.as_query_engine()\n",
    "query_engine.update_prompts({\"response_synthesizer:text_qa_template\": qa_template})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5) Query the vector database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.25` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    }
   ],
   "source": [
    "# Definition-style question sent through the query engine.\n",
    "response = query_engine.query(\n",
    "    \"What is dehydration?\"\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Answer: 2 tbsp. butter are required for the Cream Puffs.\n"
     ]
    }
   ],
   "source": [
    "# Factual lookup. Correct answer is 2\n",
    "response = query_engine.query(\n",
    "    \"How many tbsp. butter to use for the Cream Puffs?\"\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Answer: 2 tbsp. butter are required for the Cream Puffs.\n"
     ]
    }
   ],
   "source": [
    "# Factual lookup. Correct answer is 2\n",
    "# NOTE(review): the same Cream Puffs question is asked in the preceding cell;\n",
    "# consider dropping one of the two.\n",
    "response = query_engine.query(\n",
    "    \"How many tbsp. butter to use for the Cream Puffs?\"\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "The toppings for the Braided Bread of the Braids, Coffeecake, and Sweet Rolls recipe are:\n",
      "\n",
      "1. 2 tsp. caraway seeds and ½ cup shredded Cheddar cheese.\n",
      "2. ½ cup diced Swiss cheese and paprika.\n"
     ]
    }
   ],
   "source": [
    "# Multi-part lookup. Correct answer is\n",
    "# 2 tsp. caraway seeds and ½ cup shredded Cheddar cheese.\n",
    "# ½ cup diced Swiss cheese and paprika.\n",
    "# ...\n",
    "response = query_engine.query(\n",
    "    \"What are the toppings for the Braided Bread of the Braids, \"\n",
    "    \"Coffeecake, and Sweet Rolls recipe?\"\n",
    ")\n",
    "print(response)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
